#include "utils.h"
#include <thrust/execution_policy.h>
#include <thrust/scan.h>

// Capacity (in elements) of the per-thread rgb_tmp scratch array declared in
// composite_train_bw_kernel. That array is indexed by the feature index, so
// the feature dimension handled by the backward pass must not exceed this.
const int MAX_FEAT_SIZE = 128;

// Training forward pass: front-to-back compositing, one thread per row of
// rays_a.
//
// Each row of rays_a is read as (ray_idx, start_idx, N_samples); the thread
// walks samples start_idx .. start_idx + N_samples - 1 of the flat per-sample
// inputs. For every visited sample it computes
//   alpha  = 1 - exp(-sigma * delta)
//   weight = alpha * T          (T starts at 1 and is multiplied by 1 - alpha)
// and accumulates weight-scaled rgbs / ts / 1 into rgb, depth and opacity of
// ray_idx, storing the weight itself in ws. The walk stops early once
// T <= T_threshold.
//
// The outputs are accumulated with +=, so they are expected to arrive
// zero-filled. Threads with index >= opacity.size(0) do nothing.
//
// NOTE(review): when the loop stops early, the sample that triggered the stop
// has been accumulated but is not included in total_samples — confirm this is
// the intended meaning of the counter.
template <typename scalar_t>
__global__ void composite_train_fw_kernel(
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits,
                                      size_t>
        sigmas,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits,
                                      size_t>
        rgbs,
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits,
                                      size_t>
        deltas,
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits,
                                      size_t>
        ts,
    const torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits>
        rays_a,
    const scalar_t T_threshold,
    torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits>
        total_samples,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>
        opacity,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>
        depth,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>
        rgb,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>
        ws) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= opacity.size(0))
    return;

  // Unpack this thread's ray descriptor.
  const int ray_idx = rays_a[tid][0];
  const int start_idx = rays_a[tid][1];
  const int N_samples = rays_a[tid][2];

  // Number of feature channels carried per sample.
  const int n_feat = rgbs.size(1);

  scalar_t T = 1.0f; // transmittance remaining in front of the current sample
  int samples = 0;   // samples fully stepped past so far

  for (; samples < N_samples; ++samples) {
    const int s = start_idx + samples;
    const scalar_t alpha = 1.0f - __expf(-sigmas[s] * deltas[s]);
    const scalar_t weight = alpha * T; // contribution of this sample

    // Works for arbitrary feature vectors, not only 3-channel colour.
    for (int c = 0; c < n_feat; ++c)
      rgb[ray_idx][c] += weight * rgbs[s][c];
    depth[ray_idx] += weight * ts[s];
    opacity[ray_idx] += weight;
    ws[s] = weight;

    T *= 1.0f - alpha;

    // The ray is opaque enough; leaving via break skips the ++samples above.
    if (T <= T_threshold)
      break;
  }

  total_samples[ray_idx] = samples;
}

// Host launcher for composite_train_fw_kernel.
//
// Allocates zero-filled outputs with the dtype/device of `sigmas` (int64 for
// total_samples), launches one thread per row of `rays_a` in blocks of 256,
// and returns {total_samples, opacity, depth, rgb, ws}:
//   total_samples, opacity, depth : (N_rays)
//   rgb                           : (N_rays, rgbs.size(1))
//   ws                            : (sigmas.size(0))
//
// When rays_a has no rows the zero-filled outputs are returned without a
// launch. A failed kernel launch is reported through TORCH_CHECK.
std::vector<torch::Tensor>
composite_train_fw_cu(const torch::Tensor sigmas, const torch::Tensor rgbs,
                      const torch::Tensor deltas, const torch::Tensor ts,
                      const torch::Tensor rays_a, const float T_threshold) {
  const int N_rays = rays_a.size(0), N = sigmas.size(0);

  auto opacity = torch::zeros({N_rays}, sigmas.options());
  auto depth = torch::zeros({N_rays}, sigmas.options());
  auto rgb = torch::zeros({N_rays, rgbs.size(1)}, sigmas.options());
  auto ws = torch::zeros({N}, sigmas.options());
  auto total_samples = torch::zeros(
      {N_rays}, torch::dtype(torch::kLong).device(sigmas.device()));

  // A grid of zero blocks is an invalid launch configuration; with no rays
  // there is nothing to composite, so hand back the zero-filled outputs.
  if (N_rays == 0)
    return {total_samples, opacity, depth, rgb, ws};

  const int threads = 256, blocks = (N_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      sigmas.scalar_type(), "composite_train_fw_cu", ([&] {
        composite_train_fw_kernel<scalar_t><<<blocks, threads>>>(
            sigmas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                   size_t>(),
            rgbs.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits,
                                 size_t>(),
            deltas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                   size_t>(),
            ts.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
            rays_a.packed_accessor64<int64_t, 2, torch::RestrictPtrTraits>(),
            T_threshold,
            total_samples
                .packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
            opacity.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                    size_t>(),
            depth.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                  size_t>(),
            rgb.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits,
                                size_t>(),
            ws.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                               size_t>());
      }));

  // Kernel launches do not return a status; query it explicitly so a bad
  // launch is reported here instead of at some unrelated later CUDA call.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "composite_train_fw_kernel launch failed: ",
              cudaGetErrorString(launch_err));

  return {total_samples, opacity, depth, rgb, ws};
}

// Training backward pass: gradients of the composited outputs with respect to
// the per-sample sigmas and rgbs, one thread per row of rays_a.
//
// Each row of rays_a is read as (ray_idx, start_idx, N_samples). The thread
// replays the forward compositing over its samples (same alpha / weight / T
// recurrence, same T_threshold early stop) and writes
//   dL_drgbs[s][i] = dL_drgb[ray_idx][i] * w
//   dL_dsigmas[s]  = deltas[s] * (rgb term + opacity term + depth term +
//                                 ws term)
// using the forward results opacity / depth / rgb of the ray.
//
// dL_dws_times_ws must point to the element-wise product dL_dws * ws. It is
// overwritten in place: each thread replaces its ray's segment with the
// inclusive prefix sum of that segment.
//
// Precondition: dL_drgb.size(1) <= MAX_FEAT_SIZE. The per-thread scratch
// array rgb_tmp has MAX_FEAT_SIZE elements and is indexed by the feature
// index without a bounds check.
//
// Samples not visited (after an early stop, or on rays with no samples) are
// not written, so dL_dsigmas / dL_drgbs keep whatever they held on entry.
template <typename scalar_t>
__global__ void composite_train_bw_kernel(
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits,
                                      size_t>
        dL_dopacity,
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits,
                                      size_t>
        dL_ddepth,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits,
                                      size_t>
        dL_drgb,
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits,
                                      size_t>
        dL_dws,
    scalar_t *__restrict__ dL_dws_times_ws,
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits,
                                      size_t>
        sigmas,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits,
                                      size_t>
        rgbs,
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits,
                                      size_t>
        deltas,
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits,
                                      size_t>
        ts,
    const torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits>
        rays_a,
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits,
                                      size_t>
        opacity,
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits,
                                      size_t>
        depth,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits,
                                      size_t>
        rgb,
    const scalar_t T_threshold,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>
        dL_dsigmas,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>
        dL_drgbs) {
  const int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n >= opacity.size(0))
    return;

  const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1],
            N_samples = rays_a[n][2];

  // A ray without samples has no gradient to write. Returning here also keeps
  // the "last element of the segment" read below from touching
  // dL_dws_times_ws[start_idx - 1], which lies outside this ray's segment
  // (and outside the buffer when start_idx == 0).
  if (N_samples <= 0)
    return;

  const int n_feat = dL_drgb.size(1);

  // front to back compositing
  int samples = 0;
  scalar_t O = opacity[ray_idx], D = depth[ray_idx];
  scalar_t T = 1.0f, d = 0.0f;
  // Running sum of w * rgbs over the samples visited so far. Every element is
  // zero-initialised here; a memset with an element count instead of a byte
  // count would clear only a quarter of the float array.
  float rgb_tmp[MAX_FEAT_SIZE] = {0.0f};

  // compute prefix sum of dL_dws * ws
  // [a0, a1, a2, a3, ...] -> [a0, a0+a1, a0+a1+a2, a0+a1+a2+a3, ...]
  // NOTE(review): thrust::device is used from inside a kernel; how that
  // policy executes in device code depends on the Thrust/CUDA version.
  // thrust::seq may be the more portable choice — confirm before changing.
  thrust::inclusive_scan(thrust::device, dL_dws_times_ws + start_idx,
                         dL_dws_times_ws + start_idx + N_samples,
                         dL_dws_times_ws + start_idx);
  // Last element of the inclusive scan = total of dL_dws * ws over the ray.
  scalar_t dL_dws_times_ws_sum = dL_dws_times_ws[start_idx + N_samples - 1];

  while (samples < N_samples) {
    const int s = start_idx + samples;
    const scalar_t a = 1.0f - __expf(-sigmas[s] * deltas[s]);
    const scalar_t w = a * T;

    d += w * ts[s]; // depth accumulated up to and including this sample
    T *= 1.0f - a;  // from here on T is the transmittance *after* sample s

    // compute gradients by math...
    for (int i = 0; i < n_feat; ++i) {
      rgb_tmp[i] += w * rgbs[s][i];
      dL_drgbs[s][i] = dL_drgb[ray_idx][i] * w;
    }

    // rgb[ray_idx][i] - rgb_tmp[i] is the colour contributed by the samples
    // behind s.
    scalar_t tmp = 0.0f;
    for (int i = 0; i < n_feat; ++i)
      tmp += dL_drgb[ray_idx][i] *
             (rgbs[s][i] * T - (rgb[ray_idx][i] - rgb_tmp[i]));

    dL_dsigmas[s] =
        deltas[s] *
        (tmp +                                        // gradients from rgb
         dL_dopacity[ray_idx] * (1 - O) +             // gradient from opacity
         dL_ddepth[ray_idx] * (ts[s] * T - (D - d)) + // gradient from depth
         T * dL_dws[s] -
         (dL_dws_times_ws_sum - dL_dws_times_ws[s]) // gradient from ws
        );

    if (T <= T_threshold)
      break; // ray has enough opacity
    samples++;
  }
}

// Host launcher for composite_train_bw_kernel.
//
// Allocates zero-filled dL_dsigmas (sigmas.size(0)) and dL_drgbs
// (sigmas.size(0), rgb.size(1)) with the dtype/device of `sigmas`, builds the
// auxiliary buffer dL_dws * ws that the kernel scans in place, launches one
// thread per row of `rays_a` in blocks of 256, and returns
// {dL_dsigmas, dL_drgbs}.
//
// Raises (via TORCH_CHECK) when the feature dimension of dL_drgb exceeds
// MAX_FEAT_SIZE, because the kernel keeps a per-thread scratch array of that
// fixed size, and when the kernel launch fails. When rays_a has no rows the
// zero-filled gradients are returned without a launch.
std::vector<torch::Tensor>
composite_train_bw_cu(const torch::Tensor dL_dopacity,
                      const torch::Tensor dL_ddepth,
                      const torch::Tensor dL_drgb, const torch::Tensor dL_dws,
                      const torch::Tensor sigmas, const torch::Tensor rgbs,
                      const torch::Tensor ws, const torch::Tensor deltas,
                      const torch::Tensor ts, const torch::Tensor rays_a,
                      const torch::Tensor opacity, const torch::Tensor depth,
                      const torch::Tensor rgb, const float T_threshold) {
  const int N = sigmas.size(0), N_rays = rays_a.size(0);

  // The kernel indexes a float[MAX_FEAT_SIZE] scratch array with the feature
  // index taken from dL_drgb.size(1); anything larger would overrun it.
  TORCH_CHECK(dL_drgb.size(1) <= MAX_FEAT_SIZE,
              "composite_train_bw_cu: feature size ", dL_drgb.size(1),
              " exceeds MAX_FEAT_SIZE (", MAX_FEAT_SIZE, ")");

  auto dL_dsigmas = torch::zeros({N}, sigmas.options());
  auto dL_drgbs = torch::zeros({N, rgb.size(1)}, sigmas.options());

  // A grid of zero blocks is an invalid launch configuration; with no rays
  // every gradient stays zero.
  if (N_rays == 0)
    return {dL_dsigmas, dL_drgbs};

  auto dL_dws_times_ws = dL_dws * ws; // auxiliary input

  const int threads = 256, blocks = (N_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      sigmas.scalar_type(), "composite_train_bw_cu", ([&] {
        composite_train_bw_kernel<scalar_t><<<blocks, threads>>>(
            dL_dopacity.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                        size_t>(),
            dL_ddepth.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                      size_t>(),
            dL_drgb.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits,
                                    size_t>(),
            dL_dws.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                   size_t>(),
            dL_dws_times_ws.data_ptr<scalar_t>(),
            sigmas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                   size_t>(),
            rgbs.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits,
                                 size_t>(),
            deltas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                   size_t>(),
            ts.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
            rays_a.packed_accessor64<int64_t, 2, torch::RestrictPtrTraits>(),
            opacity.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                    size_t>(),
            depth.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                  size_t>(),
            rgb.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits,
                                size_t>(),
            T_threshold,
            dL_dsigmas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                       size_t>(),
            dL_drgbs.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits,
                                     size_t>());
      }));

  // Kernel launches do not return a status; query it explicitly so a bad
  // launch is reported here instead of at some unrelated later CUDA call.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "composite_train_bw_kernel launch failed: ",
              cudaGetErrorString(launch_err));

  return {dL_dsigmas, dL_drgbs};
}

// Inference forward pass: continues front-to-back compositing for the rays
// listed in alive_indices, one thread per entry.
//
// Thread n reads the first N_eff_samples[n] samples of row n of sigmas / rgbs /
// deltas / ts and accumulates them into opacity, depth and rgb of ray
// r = alive_indices[n]. The transmittance is resumed from the ray's current
// opacity (T = 1 - opacity[r]), so the outputs carry state between calls.
//
// alive_indices[n] is set to -1 when the entry has no samples
// (N_eff_samples[n] == 0) or when T drops to T_threshold or below during the
// walk. hits_t is accepted but not read by this kernel.
template <typename scalar_t>
__global__ void composite_test_fw_kernel(
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits,
                                      size_t>
        sigmas,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits,
                                      size_t>
        rgbs,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits,
                                      size_t>
        deltas,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits,
                                      size_t>
        ts,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits,
                                      size_t>
        hits_t,
    torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits>
        alive_indices,
    const scalar_t T_threshold,
    const torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits>
        N_eff_samples,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>
        opacity,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>
        depth,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>
        rgb) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= alive_indices.size(0))
    return;

  // Nothing was sampled for this entry: mark it as finished and leave.
  const int n_samples = N_eff_samples[tid];
  if (n_samples == 0) {
    alive_indices[tid] = -1;
    return;
  }

  const size_t r = alive_indices[tid]; // ray this entry accumulates into
  const int n_feat = rgbs.size(2);     // feature channels per sample

  // Resume from whatever opacity the ray has already gathered.
  scalar_t T = 1 - opacity[r];

  for (int s = 0; s < n_samples; ++s) {
    const scalar_t alpha = 1.0f - __expf(-sigmas[tid][s] * deltas[tid][s]);
    const scalar_t weight = alpha * T;

    for (int c = 0; c < n_feat; ++c)
      rgb[r][c] += weight * rgbs[tid][s][c];

    depth[r] += weight * ts[tid][s];
    opacity[r] += weight;

    T *= 1.0f - alpha;
    if (T <= T_threshold) {
      // Opaque enough: retire the ray and stop consuming samples.
      alive_indices[tid] = -1;
      break;
    }
  }
}

// Host launcher for composite_test_fw_kernel.
//
// Launches one thread per entry of `alive_indices` in blocks of 256. The
// kernel updates `alive_indices`, `opacity`, `depth` and `rgb` in place;
// nothing is returned.
//
// With an empty `alive_indices` the call is a no-op. A failed kernel launch
// is reported through TORCH_CHECK.
void composite_test_fw_cu(const torch::Tensor sigmas, const torch::Tensor rgbs,
                          const torch::Tensor deltas, const torch::Tensor ts,
                          const torch::Tensor hits_t,
                          torch::Tensor alive_indices, const float T_threshold,
                          const torch::Tensor N_eff_samples,
                          torch::Tensor opacity, torch::Tensor depth,
                          torch::Tensor rgb) {
  const int N_rays = alive_indices.size(0);

  // A grid of zero blocks is an invalid launch configuration; with no alive
  // rays there is nothing to update.
  if (N_rays == 0)
    return;

  const int threads = 256, blocks = (N_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      sigmas.scalar_type(), "composite_test_fw_cu", ([&] {
        composite_test_fw_kernel<scalar_t><<<blocks, threads>>>(
            sigmas.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits,
                                   size_t>(),
            rgbs.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits,
                                 size_t>(),
            deltas.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits,
                                   size_t>(),
            ts.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            hits_t.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits,
                                   size_t>(),
            alive_indices
                .packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
            T_threshold,
            N_eff_samples.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
            opacity.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                    size_t>(),
            depth.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits,
                                  size_t>(),
            rgb.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits,
                                size_t>());
      }));

  // Kernel launches do not return a status; query it explicitly so a bad
  // launch is reported here instead of at some unrelated later CUDA call.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "composite_test_fw_kernel launch failed: ",
              cudaGetErrorString(launch_err));
}
